# https://hrl.boyuai.com/chapter/2/sac%E7%AE%97%E6%B3%95/
# Maximum-entropy RL: balance exploration and exploitation by controlling the entropy of the actions taken by the policy
# https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/discrete_sac.py
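# For reference, the soft (maximum-entropy) objective that SAC optimizes:
#   J(pi) = E[ sum_t ( r_t + alpha * H(pi(.|s_t)) ) ],  with  H(pi(.|s)) = -sum_a pi(a|s) * log pi(a|s)
# alpha is the temperature trading off reward against policy entropy; it is learned automatically below.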
import random
import gym
import numpy as np
import torch
import torch.nn.functional as F
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class QValueNet(torch.nn.Module):
    ''' Q network with a single hidden layer; for discrete actions it outputs one Q value per action (rather than taking the action as an input) '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
    
class SAC:
    ''' SAC algorithm for discrete action spaces '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 alpha_lr, target_entropy, para_soft_update, discount_factor, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)  # policy network
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # first Q network
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.target_critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # first target Q network
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())

        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # second Q network
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        self.target_critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # second target Q network
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

        # Learn log(alpha) rather than alpha itself: exp() keeps alpha positive and training is more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # alpha is optimized by gradient descent
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        self.target_entropy = target_entropy  # target policy entropy
        self.discount_factor = discount_factor
        self.para_soft_update = para_soft_update
        self.device = device
        
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    # Compute the target Q value; the expectation over next actions is taken directly with the policy's output probabilities
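    # For discrete actions the soft TD target can be written exactly (no sampling of a' is needed):
    #   y = r + gamma * (1 - done) * [ sum_a pi(a|s') * min(Q1'(s', a), Q2'(s', a)) + alpha * H(pi(.|s')) ]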
    def calc_target(self, rewards, next_states, dones):
        next_probs = self.actor(next_states)
        next_log_probs = torch.log(next_probs + 1e-8)
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        # Take the element-wise minimum of the two target Q networks (clipped double-Q)
        q1_value = self.target_critic_1(next_states)
        q2_value = self.target_critic_2(next_states)
        min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)
        next_value = min_qvalue + self.log_alpha.exp() * entropy
        td_target = rewards + self.discount_factor * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
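        # Polyak (soft) update: theta_target <- tau * theta + (1 - tau) * theta_target,
        # with tau = para_soft_update kept small (e.g. 0.005) so the targets track the online critics slowly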
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.para_soft_update) + param.data * self.para_soft_update)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'], dtype=torch.long).view(-1, 1).to(self.device)  # actions are integer indices, not floats
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        '''Update the two Q networks'''
        td_target = self.calc_target(rewards, next_states, dones)
        
        critic_1_q_values = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(F.mse_loss(critic_1_q_values, td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        
        critic_2_q_values = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(F.mse_loss(critic_2_q_values, td_target.detach()))
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        '''Update the policy network'''
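        # Discrete-action actor loss, again an exact expectation over actions:
        #   L(theta) = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a))) ]
        #            = E_s[ -alpha * H(pi(.|s)) - E_{a~pi}[ min Q(s, a) ] ]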
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        # Compute the entropy directly from the action probabilities
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)  # expectation over actions, taken directly with the probabilities
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        '''Update the entropy-regularization coefficient alpha'''
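        # Temperature loss: L(alpha) = E[ alpha * (H(pi(.|s)) - target_entropy) ] with the entropy term detached.
        # Gradient descent raises alpha when the policy entropy falls below the target and lowers it otherwise.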
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        '''Soft-update the two target Q networks'''
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
discount_factor = 0.98
para_soft_update = 0.005  # soft-update (Polyak) coefficient
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
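# Note: -1 is the target entropy used here for CartPole; other discrete-SAC implementations often
# use a fraction of the maximum entropy instead, e.g. 0.98 * (-np.log(1.0 / action_dim)).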
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v1'
env = gym.make(env_name)
seed2024 = 0
random.seed(seed2024)
np.random.seed(seed2024)
env.reset(seed=seed2024)
torch.manual_seed(seed2024)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n


agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,
            target_entropy, para_soft_update, discount_factor, device)
alg_name = 'SAC'


print('Training!!!!')
return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size)
rl_utils.plot_results(return_list, env_name, alg_name, string_train_test='Training', moving_average_weight=9)


print('Testing!!!!')
return_list_test = rl_utils.test_agent(env, agent, num_episodes=50)
rl_utils.plot_results(return_list_test, env_name, alg_name, string_train_test='Testing', moving_average_weight=3)
# print('Rendering!!!!')
# rl_utils.test_agent_render(env, agent)